# Author: Kunal Sangani, ksangani@g.harvard.edu

import numpy as np
from scipy.optimize import fsolve
import matplotlib.pyplot as plt
import pandas as pd
from scipy.special import gammainc
import seaborn as sns
import support_functions as sp
import tikzplotlib
from scipy.interpolate import interp1d
from scipy.optimize import minimize
import klenow_willis as kw

# Figure sizes handed to plt.subplots(figsize=...), i.e. (width, height) in inches.
# The suffix names the subplot grid the size is meant for (rows x columns).
# NOTE(review): FIG_SIZE_2x1 is not referenced anywhere in the visible file.
FIG_SIZE_1x2 = (9,4)
FIG_SIZE_2x1 = (5,8)
FIG_SIZE_2x2 = (10,8)


#########################################
# Atkeson-Burstein oligopoly models
#########################################

# Replication of Amiti et al 2019
# Note: All calculations done assuming Cournot style competition
def price_solver(prices, productivities, wages, sigma_across, sigma_within):
	"""Residual of the within-sector pricing fixed point (root-finder objective).

	The pricing condition is
	    P_i = [1 - 1/rho + s_i (1/rho - 1/eta)]^(-1) * W_i / z_i
	where rho = sigma_within, eta = sigma_across, s_i are the market shares
	implied by the candidate prices, W_i are wages and z_i productivities.

	Returns the price vector implied by the right-hand side minus the
	candidate `prices`; the residual is zero at a solution.
	"""
	market_shares = calculate_shares(prices, sigma_within)
	inv_within = 1/sigma_within
	# Bracketed term of the pricing condition (the inverse of the markup).
	inverse_markup = 1 - inv_within + market_shares*(inv_within - 1/sigma_across)
	implied_prices = inverse_markup**(-1) * wages / productivities
	return implied_prices - prices

# Note: All calculations done assuming Cournot style competition
def calculate_shares(prices, sigma_within):
	"""Market shares implied by a vector of prices within one sector.

	CES demand at the industry level gives
	    s_i = P_i^(1-rho) / sum_j P_j^(1-rho),   rho = sigma_within.
	Shares are then capped at 0.99 and rescaled to sum to one; the cap keeps
	a single firm's share away from 1, where the pricing rule would produce a
	zero or negative price.
	"""
	SHARE_CAP = 0.99
	ces_weights = prices ** (1 - sigma_within)
	raw_shares = ces_weights / sum(ces_weights)
	# Element-wise cap, then renormalise so the shares add up to one again.
	capped = np.minimum(raw_shares, SHARE_CAP)
	return capped/sum(capped)

def calculate_markups(prices, productivities, wages):
	"""Return price over marginal cost, where marginal cost = wage / productivity."""
	marginal_cost = wages / productivities
	return prices / marginal_cost

def calculate_output(shares, prices, N_firms, sigma_within, sigma_across):
	"""Firm output relative to the CES aggregate of its sector.

	`shares` and `prices` are indexed (sector, firm). Firm quantities are
	taken to be share / price; the sector aggregate is
	    Y = [ (1/N) * sum_i y_i^((rho-1)/rho) ]^(rho/(rho-1)),  rho = sigma_within.
	`sigma_across` is accepted but not used in the calculation.

	Returns an array with the same shape as `shares` holding y_i / Y.
	"""
	quantities = shares / prices
	inner_power = (sigma_within - 1)/sigma_within
	outer_power = sigma_within/(sigma_within-1)
	sector_output = ((1/N_firms) * np.sum(quantities**inner_power, axis=1))**outer_power
	# Divide every firm in a row by that row's sector aggregate.
	return quantities / sector_output[:, np.newaxis]

# Note: All calculations done assuming Cournot style competition
def calculate_passthroughs(prices, markups, sigma_across, sigma_within):
	"""Own-cost passthrough of each firm in one sector.

	Passthrough = 1 / (1 - dlog markup / dlog p), with
	    dlog markup / dlog p = markup * (1/eta - 1/rho) * ds / dlog p
	where eta = sigma_across, rho = sigma_within and s is the CES share
	P_i^(1-rho) / sum_j P_j^(1-rho) (computed here without the share cap).
	"""
	ces_weights = prices ** (1 - sigma_within)
	total_weight = sum(ces_weights)
	# Derivative of firm i's share with respect to its own log price.
	share_slope = (1-sigma_within)*(prices**(-sigma_within)) / total_weight**2 * (total_weight - ces_weights) * prices
	markup_elasticity = markups * (1/sigma_across - 1/sigma_within) * share_slope
	return 1/(1 - markup_elasticity)

def draw_firms_amiti(n_sectors, N_firms, pareto_scale=8):
	"""Draw an (n_sectors, N_firms) array of Pareto productivities.

	Amiti et al version: Pareto with shape 8 and lower bound exp(0) = 1.
	`pareto_scale` is passed to np.random.pareto as its shape argument; the
	draw is shifted up by one so the smallest possible value is 1.
	"""
	grid_shape = (n_sectors, N_firms)
	return 1 + np.random.pareto(pareto_scale, grid_shape)

def draw_firms_ab(n_sectors, N_firms, theta=0.38, sigma=0.20):
	"""Draw an (n_sectors, N_firms) array of lognormal productivities.

	Atkeson-Burstein version: log productivity is the sum of a firm-level
	normal shock (std `theta`, default 0.38) and a sector-level normal shock
	(std `sigma`, default 0.20) shared by every firm in the sector.
	"""
	# The firm-level draw comes first so the random stream is consumed in a fixed order.
	firm_shock = np.random.normal(loc=0, scale=theta, size=(n_sectors, N_firms))
	sector_shock = np.random.normal(loc=0, scale=sigma, size=(n_sectors,1))
	return np.exp(firm_shock + sector_shock)

def oligopoly_sim(n_sectors=1000, sigma_across=1, sigma_within=10, N_firms=45, replication='ab'):
	"""Simulate `n_sectors` independent oligopolistic sectors of `N_firms` firms.

	Parameters:
	    n_sectors, N_firms: size of the simulated economy.
	    sigma_across, sigma_within: elasticities of substitution across and
	        within sectors.
	    replication: 'ab' draws lognormal productivities (draw_firms_ab); any
	        other value draws Pareto productivities (draw_firms_amiti).

	Returns (firms, prices, shares, markups, passthroughs), each an array of
	shape (n_sectors, N_firms). Prices are solved sector by sector with
	fsolve on price_solver; fsolve's convergence status is not inspected.
	"""
	# Draw only the productivity distribution that was requested, so no
	# random numbers are generated and then discarded.
	if replication == 'ab':
		firms = draw_firms_ab(n_sectors=n_sectors, N_firms=N_firms)
	else:
		firms = draw_firms_amiti(n_sectors=n_sectors, N_firms=N_firms)

	# Set W_i = 1 for all firms
	wages = np.ones((n_sectors, N_firms))

	# Initial guess for every price is sigma_within / (sigma_within - 1);
	# the remaining arrays are placeholders filled in sector by sector.
	prices = (sigma_within / (sigma_within-1)) * np.ones((n_sectors, N_firms))
	shares = np.ones((n_sectors, N_firms))
	markups = np.ones((n_sectors, N_firms))
	passthroughs = np.ones((n_sectors, N_firms))

	for i in range(n_sectors):
		prices[i,:] = fsolve(price_solver, prices[i,:], args=(firms[i,:], wages[i,:], sigma_across, sigma_within))
		shares[i,:] = calculate_shares(prices[i,:], sigma_within)
		markups[i,:] = calculate_markups(prices[i,:], firms[i,:], wages[i,:])
		passthroughs[i,:] = calculate_passthroughs(prices[i,:], markups[i,:], sigma_across, sigma_within)
		# If any markup is below 1 (price below marginal cost), print the
		# sector for troubleshooting.
		# (This happens sometimes with the lognormal AB params.)
		if min(markups[i,:]) < 1:
			print(i)
			print(firms[i,:])
			print(shares[i,:])
			print(markups[i,:])

	return firms, prices, shares, markups, passthroughs

def effective_passthroughs(shares, passthroughs, calvo):
	"""Effective passthrough of each firm given sector-level strategic feedback.

	All three inputs are arrays indexed (sector, firm). With
	    alpha = calvo * (1 - passthrough) * share / (1 - share)
	the sector-level sums are
	    num_1 = sum_j share_j * (1 - calvo_j) / (1 + alpha_j)
	    num_2 = sum_j alpha_j / (1 + alpha_j)
	and the function returns
	    1 - [ (alpha / share) * num_1 / (1 - num_2) + (1 - calvo) ] / (1 + alpha).
	"""
	alpha = calvo * (1 - passthroughs) * shares / (1 - shares)
	one_plus_alpha = 1 + alpha
	num_1 = np.sum(shares * (1 - calvo) / one_plus_alpha, axis=1)
	num_2 = np.sum(alpha / one_plus_alpha, axis=1)
	# One feedback factor per sector, applied to every firm in that row.
	sector_factor = num_1 / (1-num_2)
	feedback = alpha * sector_factor[:, np.newaxis]
	return 1 - ((feedback / shares + (1-calvo)) / one_plus_alpha)

def philips_slopes(shares, markups, passthroughs, frisch=0.5, inc_elast=0.5, pi=0.20, sigma_within=10):
	"""Phillips-curve slopes for the oligopoly model.

	`shares`, `markups` and `passthroughs` are indexed (sector, firm). `pi` is
	applied uniformly to every firm as the Calvo parameter.

	Each slope is first computed as dlog Y / dlog w and inverted on return.
	Returns a 6-tuple:
	    (sticky prices only, with real rigidities, with misallocation)
	for the wage Phillips curve, followed by the same three for the price
	Phillips curve.
	"""
	# NOTE(review): sp.weighted_exp / sp.weighted_cov are assumed to be pure
	# share-weighted mean / covariance helpers -- confirm in support_functions.
	calvo = pi * np.ones(shares.shape)
	weights = shares.flatten()
	one_minus_calvo = 1 - calvo.flatten()
	one_minus_eff_pass = 1 - effective_passthroughs(shares, passthroughs, calvo).flatten()

	# Inverse markup of each firm relative to its sector's share-weighted inverse markup.
	inv_markups = 1/markups
	inv_sector_markups = np.sum(inv_markups * shares, axis=1)
	rel_inv_markup = (inv_markups.T / inv_sector_markups).T.flatten()

	labor_scale = frisch / (1 + inc_elast)

	# Slope if just sticky prices
	mean_sticky = sp.weighted_exp(one_minus_calvo, weights)
	calvo_channel = labor_scale * mean_sticky
	calvo_channel_price = calvo_channel / (1 - mean_sticky)

	# Slope with real rigidities
	mean_rigid = sp.weighted_exp(one_minus_eff_pass, weights)
	real_rigidities_channel = labor_scale * mean_rigid
	real_rigid_channel_price = real_rigidities_channel / (1 - mean_rigid)

	# Slope with misallocation channel
	misallocation = (1/(1+inc_elast)) * sigma_within * sp.weighted_cov(1-rel_inv_markup, one_minus_eff_pass, weights)
	philips_wage_slope = real_rigidities_channel + misallocation
	philips_price_slope = real_rigid_channel_price + misallocation / (1 - mean_rigid)

	return (
		1/calvo_channel,
		1/real_rigidities_channel,
		1/philips_wage_slope,
		1/calvo_channel_price,
		1/real_rigid_channel_price,
		1/philips_price_slope,
	)

def compare_passthrough_to_amiti_empirics(firms, prices, shares, markups, passthroughs):
	"""Plot model passthroughs against the Amiti et al empirical estimates.

	Draws two figures: (1) a scatter of passthrough against sales share with
	the model's sales-weighted average and the Amiti et al Table 2 estimates
	as horizontal lines; (2) the sales-weighted average passthrough of the
	smallest firms up to a cumulative sales-share cutoff, against approximate
	readings of Amiti et al Figure A2. Only `shares` and `passthroughs` are
	used; the other arguments are ignored.
	"""
	flat_shares = shares.flatten()
	flat_pass = passthroughs.flatten()

	# Figure 1: passthrough by sales share, with the Table 2 estimates.
	plt.figure()
	plt.scatter(flat_shares, flat_pass, label='$\\rho_i^{flex}$')
	plt.xlabel('Sales share')
	plt.ylabel('$\\rho_i^{flex}$')
	plt.title('$\\rho_i^{flex}$ compared to Amiti et al estimates')
	plt.axhline(sp.weighted_exp(flat_pass, flat_shares), linewidth=1, linestyle='--', color='black', label='Sales-weighted average (model)')
	amiti_table2 = (
		(0.478, '-.', 'Large firms estimate (Amiti et al)'),
		(0.972, ':', 'Small firms estimate (Amiti et al)'),
	)
	for estimate, line_style, legend_label in amiti_table2:
		plt.axhline(estimate, linewidth=1, linestyle=line_style, color='red', label=legend_label)
	plt.legend()

	# Figure 2: sales-weighted average passthrough of the smallest firms
	# whose cumulative sales share is at or below each cutoff.
	cutoffs = np.linspace(0.05, 1)
	ranked = pd.DataFrame({
		'share' : flat_shares,
		'passthrough' : flat_pass
		})
	ranked = ranked.sort_values(by='share').reset_index()
	ranked['cumul_share'] = ranked['share'].cumsum() / sum(ranked['share'])
	model_curve = np.zeros(len(cutoffs))
	for k, cutoff in enumerate(cutoffs):
		small_firms = ranked.loc[ranked['cumul_share'] <= cutoff]
		model_curve[k] = sp.weighted_exp(small_firms['passthrough'], small_firms['share'])

	# Amiti et al data from Figure A2 (approximate)
	amiti_A2 = pd.DataFrame({
		'share' : [0.23, 0.35, 0.41, 0.48, 0.54, 0.58, 0.67, 1],
		'passthrough_est' : [0.98, 0.88, 0.92, 0.82, 0.65, 0.68, 0.61, 0.63],
		'pt_c95' : [1.38, 1.2, 1.22, 1.1, 0.99, 1.05, 1.01, 0.88],
		'pt_c05' : [0.66, 0.6, 0.68, 0.62, 0.35, 0.40, 0.37, 0.42]
		})
	# Error-bar lengths below and above the point estimate.
	amiti_A2 = amiti_A2.assign(
		yerr_top=amiti_A2['pt_c95'] - amiti_A2['passthrough_est'],
		yerr_btm=amiti_A2['passthrough_est'] - amiti_A2['pt_c05'])
	error_bars = np.asarray(amiti_A2[['yerr_btm', 'yerr_top']]).T

	plt.figure()
	plt.plot(cutoffs, model_curve, label='Model')
	plt.xlabel('Share of sales below size threshold')
	plt.ylabel('Sales-weighted average passthrough')
	plt.errorbar(amiti_A2['share'], amiti_A2['passthrough_est'],
		yerr=error_bars, fmt='--o',
		capsize=2, label='Empirical est (Amiti et al, Figure A2)')
	plt.xlim((0.1,1.05))
	plt.ylim((0, 1.4))
	plt.legend()

def plot_darwinian_distributions(firms, shares, markups, passthroughs, sort_by='share'):
	"""Plot firm-level distributions against normalised rank and export two to TikZ.

	Firms are pooled, sorted by the `sort_by` column ('productivity', 'share',
	'markup' or 'passthrough') and given a rank in [0, 1). A 2x2 scatter
	figure shows passthrough, log sales share, markup and productivity by
	rank. Two further line figures (passthrough and markup by rank) are
	written with tikzplotlib to ../Draft/figures/.
	"""
	all_firms = pd.DataFrame({
		'productivity' : firms.flatten(),
		'share': shares.flatten(),
		'markup': markups.flatten(),
		'passthrough': passthroughs.flatten()})
	all_firms = all_firms.sort_values(by=sort_by).reset_index()
	all_firms['norm_rank'] = all_firms.index / len(all_firms)
	rank = all_firms['norm_rank']

	fig,axs = plt.subplots(2,2, figsize=FIG_SIZE_2x2)
	fig.suptitle('Distributions (firms ordered by ' + sort_by + ')')
	panels = (
		(axs[0][0], all_firms['passthrough'], 'Passthrough ($\\rho_i^{flex}$)'),
		(axs[0][1], np.log(all_firms['share']), 'Log sales share ($\\log \\lambda_i$)'),
		(axs[1][0], all_firms['markup'], 'Markup ($\\mu_i$)'),
		(axs[1][1], all_firms['productivity'], 'Productivity ($A_i$)'),
	)
	for ax, series, panel_title in panels:
		ax.scatter(rank, series, alpha=0.2, marker='.')
		ax.set_title(panel_title)
	# Leave room at the top of the figure for the suptitle.
	plt.tight_layout(rect=[0, 0, 1, 0.95])

	# Stand-alone line plots exported as TikZ for the draft.
	tikz_exports = (
		('passthrough', '../Draft/figures/oligopoly_passthroughs.tex'),
		('markup', '../Draft/figures/oligopoly_markups.tex'),
	)
	for column, outfile in tikz_exports:
		plt.figure()
		plt.plot(rank, all_firms[column])
		tikzplotlib.clean_figure()
		tikzplotlib.save(outfile, extra_axis_parameters=['PlotStyle'])


def amiti_replication_plots(firms, prices, shares, markups, passthroughs):
	"""Replicate the appendix figures of Amiti et al (2019) from simulated data.

	Draws, in order: a histogram of sector HHIs, a histogram of the largest
	firm's share per sector, the cumulative sales share of firms ranked from
	largest to smallest, and a 1x2 figure of markup and own-cost passthrough
	against market share for the baseline and two alternative
	parameterisations (sigma_across=2 and sigma_within=5). `firms` and
	`prices` are accepted but not used.
	"""
	# NOTE(review): the two comparison simulations use oligopoly_sim's default
	# replication ('ab') and default n_sectors, regardless of how the baseline
	# arrays passed in were generated -- confirm this is intended.
	_, _, shares2, markups2, passthroughs2 = oligopoly_sim(sigma_across=2)
	_, _, shares3, markups3, passthroughs3 = oligopoly_sim(sigma_within=5)
	hhi = np.sum(shares ** 2, axis=1)
	largest_firm = np.max(shares, axis=1)

	# Two histograms with identical styling; the weights turn counts into frequencies.
	histograms = (
		(hhi, 'HHI', 'Distribution of industry HHI with Amiti parameters'),
		(largest_firm, 'Market share of largest firm', 'Distribution of market share of largest firm (Amiti parameters)'),
	)
	for values, x_label, hist_title in histograms:
		plt.figure()
		plt.hist(values, bins=20, weights=(np.ones(len(values)) / len(values)))
		plt.axvline(np.median(values), color='black', linewidth=1, linestyle='--', label='Median')
		plt.axvline(np.mean(values), color='black', linewidth=1, linestyle='-.', label='Average')
		plt.xlabel(x_label)
		plt.ylabel('Frequency')
		plt.title(hist_title)
		plt.legend()

	# Concentration curve: cumulative share of firms sorted from largest to
	# smallest, normalised so the curve ends at one.
	plt.figure()
	cumul_shares = np.cumsum(np.sort(shares.flatten())[::-1])
	cumul_shares = cumul_shares / cumul_shares[-1]
	plt.plot(np.linspace(0,1,len(cumul_shares)), cumul_shares)
	plt.axvline(0.20, color='black', linewidth=1, linestyle='--')
	plt.axhline(0.60, color='black', linewidth=1, linestyle='--')
	plt.xlabel('Percent of firms')
	plt.ylabel('Cumulative share')
	plt.title('Replication of 20% of firms have 60% of market share')

	# Markups (left) and passthroughs (right) by market share.
	cases = (
		(shares, markups, passthroughs, 'Baseline'),
		(shares2, markups2, passthroughs2, '$\\eta = 2$'),
		(shares3, markups3, passthroughs3, '$\\rho=5$'),
	)
	fig,axs = plt.subplots(1,2, figsize=FIG_SIZE_1x2)
	for case_shares, case_markups, case_pass, case_label in cases:
		axs[0].scatter(case_shares.flatten(), case_markups.flatten(), label=case_label)
		axs[1].scatter(case_shares.flatten(), case_pass.flatten(), label=case_label)
	for ax, y_label in ((axs[0], 'Markup'), (axs[1], 'Own cost passthrough')):
		ax.set_xlim((0, 0.25))
		ax.set_xlabel('Market share')
		ax.set_ylabel(y_label)
		ax.legend()

	plt.tight_layout(rect=[0, 0, 1, 0.95])
	plt.suptitle('Replication of Figure A3 from Amiti et al (2019)')

#########################################
# Kimball monopolistic competition models
#########################################

def philips_slopes_kimball(shares, markups, passthroughs, frisch=0.2, inc_elast=0.2, pi=0.50):
	"""Phillips-curve slopes for the Kimball monopolistic-competition model.

	`pi` is applied uniformly as the Calvo parameter. Firm demand elasticities
	are backed out of markups as 1 / (1 - 1/markup).

	Each slope is first computed as dlog Y / dlog w and inverted on return.
	Returns a 6-tuple:
	    (sticky prices only, with real rigidities, with misallocation)
	for the wage Phillips curve, followed by the same three for the price
	Phillips curve.
	"""
	# NOTE(review): sp.weighted_exp is assumed to be a pure weighted-mean
	# helper -- confirm in support_functions.
	calvo = pi * np.ones(shares.shape)
	elast = 1/(1 - (1/markups))
	sharemu = shares/markups
	one_minus_calvo = 1 - calvo
	flex_elast = one_minus_calvo * elast
	reset_elast = calvo * passthroughs * elast
	mean_combined_elast = sp.weighted_exp((calvo*passthroughs + one_minus_calvo)*elast, shares)
	mean_one_minus_calvo = sp.weighted_exp(one_minus_calvo, shares)

	# E_\lambda [ d(log \mu_\theta) / dw ]
	rigidity_term = sp.weighted_exp(calvo * (1-passthroughs), shares) * sp.weighted_exp(flex_elast, shares) / mean_combined_elast
	dmu_dw = -(rigidity_term + mean_one_minus_calvo)
	dmu_dw_norealrigid = -mean_one_minus_calvo

	# dlogA/dw
	dA_dw = sp.weighted_exp(flex_elast, shares) * sp.weighted_exp(reset_elast, sharemu)
	dA_dw = dA_dw - sp.weighted_exp(flex_elast, sharemu) * sp.weighted_exp(reset_elast, shares)
	dA_dw = dA_dw / mean_combined_elast

	income_scale = 1/(1+inc_elast)

	# Slope if just sticky prices
	calvo_channel = income_scale * (-frisch * dmu_dw_norealrigid)
	calvo_channel_price = calvo_channel / (1 + dmu_dw_norealrigid)

	# Slope with real rigidities
	real_rigidities_channel = income_scale * (-frisch * dmu_dw)
	real_rigid_channel_price = real_rigidities_channel / (1 + dmu_dw)

	# Slope with misallocation channel
	philips_wage_slope = real_rigidities_channel + income_scale * dA_dw
	philips_price_slope = real_rigid_channel_price + income_scale * dA_dw / (1 + dmu_dw)

	return (
		1/calvo_channel,
		1/real_rigidities_channel,
		1/philips_wage_slope,
		1/calvo_channel_price,
		1/real_rigid_channel_price,
		1/philips_price_slope,
	)

def compare_passthrough_distributions(output_oligopoly, passthroughs_oligopoly, output_dar, passthroughs_dar):
	"""Overlay passthrough-by-relative-output curves from several models.

	Plots the oligopoly simulation as a scatter, three Klenow-Willis curves
	(epsilon = 0, 1.6 and 10, all with theta-bar = 5) evaluated on a grid from
	0.0001 to max(output_dar), and the Darwinian data as a line.
	"""
	# NOTE(review): calculate_klenow_willis_passthroughs is neither defined nor
	# imported in the visible part of this file; calling this function as-is
	# would raise NameError unless the name is provided elsewhere -- confirm
	# where it lives (possibly the klenow_willis module).
	klenow_willis_cases = (
		(0, 'CES baseline ($\\bar{\\theta} = 5, \\epsilon=0$)'),
		(1.6, 'Amiti et al calibration ($\\bar{\\theta} = 5, \\epsilon=1.6$)'),
		(10, 'Klenow-Willis 2016 ($\\bar{\\theta} = 5, \\epsilon=10$)'),
	)
	shares_grid = np.linspace(0.0001,max(output_dar),1000)

	plt.figure()
	plt.scatter(output_oligopoly.flatten(), passthroughs_oligopoly.flatten(), label='Oligopoly model', alpha=0.3, marker='.')
	for epsilon, curve_label in klenow_willis_cases:
		plt.plot(*calculate_klenow_willis_passthroughs(shares_grid, 5, epsilon), label=curve_label)
	plt.plot(output_dar, passthroughs_dar, label='Darwinian data')
	plt.xlabel('$y_i / Y$')
	plt.ylabel('Passthrough ($\\rho_i^{flex}$)')
	plt.title('Comparison of $\\rho_i$ distributions from various models')
	plt.legend(loc='upper right')

#########################################
# Comparative statics plots: Varying Frisch, pi, average markup
#########################################

# Plots slope of Phillips curve, pinning the income/Frisch elasticities and letting pi (Calvo parameter) vary
def vary_pi(shares, markups, passthroughs, frisch_fix=0.2, inc_elast_fix=0.2, philips_fn=philips_slopes, savefile=None):
	"""Plot the six Phillips-curve slopes over a grid of Calvo parameters.

	`philips_fn` is called once per grid value with `frisch` and `inc_elast`
	held at `frisch_fix` / `inc_elast_fix`, and must return the six slopes in
	the order (calvo wage, real-rigidity wage, full wage, calvo price,
	real-rigidity price, full price). Results go to sp.plot_slopes together
	with `savefile`.
	"""
	pi_space = [0.001, 0.01, 0.02, 0.03, 0.05, 0.08, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]
	slope_names = ('calvo_wage', 'real_rigid_wage', 'phil_wage', 'calvo_price', 'real_rigid_price', 'phil_price')
	# One result array per slope, in the order philips_fn returns them.
	curves = {name: np.zeros(len(pi_space)) for name in slope_names}
	for i, pi_est in enumerate(pi_space):
		slopes = philips_fn(shares, markups, passthroughs, frisch=frisch_fix, inc_elast=inc_elast_fix, pi=pi_est)
		for name, value in zip(slope_names, slopes):
			curves[name][i] = value
	sp.plot_slopes(*(curves[name] for name in slope_names), pi_space, 'Price flexibility ($\\delta$)', savefile)
	
# Plots slope of Phillips curve, pinning the Calvo parameter and letting Frisch/income elasticities vary
def vary_frisch(shares, markups, passthroughs, pi_fix=0.5, philips_fn=philips_slopes, savefile=None):
	"""Plot the six Phillips-curve slopes over a grid of Frisch elasticities.

	For each grid value `philips_fn` is called with both `frisch` and
	`inc_elast` set to that value and `pi` held at `pi_fix`. It must return
	the six slopes in the order (calvo wage, real-rigidity wage, full wage,
	calvo price, real-rigidity price, full price). Results go to
	sp.plot_slopes together with `savefile`.
	"""
	frisch_space = [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.7, 0.8, 0.9, 1, 1.25, 1.5, 2]
	slope_names = ('calvo_wage', 'real_rigid_wage', 'phil_wage', 'calvo_price', 'real_rigid_price', 'phil_price')
	# One result array per slope, in the order philips_fn returns them.
	curves = {name: np.zeros(len(frisch_space)) for name in slope_names}
	for i, frisch_est in enumerate(frisch_space):
		# The income elasticity moves one-for-one with the Frisch elasticity.
		slopes = philips_fn(shares, markups, passthroughs, frisch=frisch_est, inc_elast=frisch_est, pi=pi_fix)
		for name, value in zip(slope_names, slopes):
			curves[name][i] = value
	sp.plot_slopes(*(curves[name] for name in slope_names), frisch_space, 'Frisch elasticity ($\\zeta$)', savefile)
	
def sales_density_response(shares, mu, rho, pi=0.5):
	"""Plot and print how the sales distribution responds to a wage change.

	`shares`, `mu` (markups) and `rho` (passthroughs) are 1-D arrays over firm
	types; `pi` is the Calvo parameter. The function plots dlog(share)/dlog(w)
	by type and the log demand elasticity by type, shows the figures, and
	prints dlogP, dlogPY, the HHI and Gini of `shares` and their log changes.
	"""
	elast = 1/(1-1/mu)
	mean_elast = sp.weighted_exp(elast, shares)
	mean_rho = sp.weighted_exp(rho, shares)
	mean_elast_rho = sp.weighted_exp(elast * rho, shares)
	# Denominator shared by the price-index and nominal-output responses.
	common_denom = pi * mean_elast_rho + (1-pi)*mean_elast
	dlogP = pi * mean_elast_rho / common_denom
	dlogPY = - pi*(1-pi)*(1-mean_rho)*mean_elast / common_denom + pi
	print('dlogP', dlogP)
	print('dlogPY', dlogPY)
	dlog_shares = (1-elast)*pi*((1-rho)*dlogP + rho) + elast*dlogP - dlogPY

	type_grid = np.linspace(0, 1, len(dlog_shares))
	plt.plot(type_grid,dlog_shares)
	plt.xlabel('$\\theta$')
	plt.ylabel('$d\\log \\lambda_\\theta / d\\log w$')
	plt.figure()
	plt.plot(type_grid,np.log(elast))
	plt.xlabel('$\\theta$')
	plt.ylabel('$\\log(\\sigma_\\theta)$')
	plt.show()

	# Concentration measures; dtheta normalises the sums by total shares.
	dtheta = 1/sum(shares)
	hhi = sum(shares**2) * dtheta
	dlog_hhi = 2/hhi * sum(dlog_shares * shares**2) * dtheta
	print('hhi: ', hhi)
	print('dlog_hhi: ', dlog_hhi)
	gini = 1 - 2*sum(np.cumsum(shares) * dtheta) * dtheta
	dlog_gini = -2/gini * sum(np.cumsum(shares * dlog_shares) * dtheta) * dtheta
	print('gini: ', gini)
	print('dlog_gini: ', dlog_gini)

def sales_density_response_plot(shares, mu, rho, label='', pi=0.5):
	"""Add one dlog(share)/dlog(w) curve to the current axes.

	Performs the same calculation as sales_density_response but only prints
	dlogP and dlogPY and plots the share response with the given legend
	`label`; it opens no new figure and does not call plt.show(), so several
	curves can be overlaid by the caller.
	"""
	elast = 1/(1-1/mu)
	mean_elast = sp.weighted_exp(elast, shares)
	mean_rho = sp.weighted_exp(rho, shares)
	mean_elast_rho = sp.weighted_exp(elast * rho, shares)
	# Denominator shared by the price-index and nominal-output responses.
	common_denom = pi * mean_elast_rho + (1-pi)*mean_elast
	dlogP = pi * mean_elast_rho / common_denom
	dlogPY = - pi*(1-pi)*(1-mean_rho)*mean_elast / common_denom + pi
	print('dlogP', dlogP)
	print('dlogPY', dlogPY)
	dlog_shares = (1-elast)*pi*((1-rho)*dlogP + rho) + elast*dlogP - dlogPY
	type_grid = np.linspace(0, 1, len(dlog_shares))
	plt.plot(type_grid,dlog_shares,label=label)
	plt.xlabel('$\\theta$')
	plt.ylabel('$d\\log \\lambda_\\theta / d\\log w$')

def sales_density_cutoffs(shares, mu, rho, pi=0.5):
	"""Plot the response of cumulative sales shares below / above each type.

	`shares`, `mu` (markups) and `rho` (passthroughs) are 1-D arrays over firm
	types; `pi` is the Calvo parameter. For every index i the function plots
	the share-weighted average of dlog(share)/dlog(w) over types [0, i)
	("below") and over types [i, n) ("above"). The first "below" point covers
	no types and is NaN, so it is left out of the drawn line.
	"""
	elast = 1/(1-1/mu)
	E_elast = sp.weighted_exp(elast, shares)
	E_rho = sp.weighted_exp(rho, shares)
	E_elastrho = sp.weighted_exp(elast * rho, shares)
	dlogP = pi * E_elastrho / (pi * E_elastrho + (1-pi)*E_elast)
	dlogPY = - pi*(1-pi)*(1-E_rho)*E_elast / (pi * E_elastrho + (1-pi)*E_elast) + pi
	print('dlogP', dlogP)
	print('dlogPY', dlogPY)
	dlog_shares = (1-elast)*pi*((1-rho)*dlogP + rho) + elast*dlogP - dlogPY

	# Running sums replace re-summing the prefix and suffix at every index.
	weights = np.asarray(shares, dtype=float)
	weighted_response = np.asarray(dlog_shares, dtype=float) * weights
	n_types = len(weighted_response)

	# "Below" excludes type i itself, so entry i uses the running sum up to i-1.
	dlog_share_below = np.full(n_types, np.nan)
	dlog_share_below[1:] = np.cumsum(weighted_response)[:-1] / np.cumsum(weights)[:-1]
	# "Above" includes type i: running sums taken from the top end.
	dlog_share_above = np.cumsum(weighted_response[::-1])[::-1] / np.cumsum(weights[::-1])[::-1]

	type_grid = np.linspace(0,1,n_types)
	plt.plot(type_grid, dlog_share_below,label='dlog(Cum. Sales Share Below $\\theta$)')
	plt.plot(type_grid, dlog_share_above,label='dlog(Cum. Sales Share Above $\\theta$)')
	plt.legend()
	plt.xlabel('$\\theta$')

def sales_density_by_agg_markup(shares, mu, rho, pi=0.5):
	"""Overlay share-response curves for a range of average markups.

	For each target average markup, a markup schedule is generated with
	sp.generate_markups(rho, shares, target) and passed to
	sales_density_response_plot; the combined figure is then shown. The `mu`
	argument is accepted but not used.
	"""
	mu_bar_space = [1.01, 1.02, 1.045, 1.09, 1.12, 1.15, 1.30, 1.40, 1.50, 1.60]
	# All markup schedules are generated before any curve is drawn.
	markups_gen = [sp.generate_markups(rho, shares, target) for target in mu_bar_space]
	for target, generated in zip(mu_bar_space, markups_gen):
		print(target)
		curve_label = '$\\bar\\mu = ' + str(target) + '$'
		sales_density_response_plot(shares, generated, rho, label=curve_label, pi=pi)
	plt.legend()
	plt.show()

def vary_mu_bar(shares, passthroughs, frisch_fix=0.2, inc_elast_fix=0.2, pi_fix=0.5, philips_fn=philips_slopes, savefile=None):
	"""Plot the six Phillips-curve slopes over a grid of average markups.

	For each target average markup a markup schedule is generated with
	sp.generate_markups(passthroughs, shares, target); `philips_fn` is then
	called with the remaining parameters held fixed and must return the six
	slopes in the order (calvo wage, real-rigidity wage, full wage, calvo
	price, real-rigidity price, full price). Results go to sp.plot_slopes
	together with `savefile`.
	"""
	mu_bar_space = [1.001, 1.01, 1.02, 1.045, 1.06, 1.09, 1.12, 1.15, 1.20, 1.25, 1.30, 1.35, 1.40, 1.45, 1.50, 1.55, 1.60]
	# All markup schedules are generated before any slope is computed.
	markups_gen = [sp.generate_markups(passthroughs, shares, target) for target in mu_bar_space]

	slope_names = ('calvo_wage', 'real_rigid_wage', 'phil_wage', 'calvo_price', 'real_rigid_price', 'phil_price')
	curves = {name: np.zeros(len(mu_bar_space)) for name in slope_names}
	for i, generated in enumerate(markups_gen):
		slopes = philips_fn(shares, generated, passthroughs, frisch=frisch_fix, inc_elast=inc_elast_fix, pi=pi_fix)
		for name, value in zip(slope_names, slopes):
			curves[name][i] = value
	sp.plot_slopes(*(curves[name] for name in slope_names), mu_bar_space, 'Average markup ($\\bar\\mu$)', savefile)

#########################################
# Driver
#########################################

def kw_version():
	"""Run the Kimball Phillips-curve exercises on Klenow-Willis distributions.

	Generates shares, markups and passthroughs with
	kw.generate_kw_distributions(superelast=1.6), prints their shapes and
	share-weighted averages, plots the slopes over the Calvo and Frisch grids
	(without saving), prints the slopes at the default parameters and shows
	the figures.
	"""
	shares,markups,passthroughs = kw.generate_kw_distributions(superelast=1.6)
	elast = 1/(1-1/markups)
	print(shares.shape, markups.shape, passthroughs.shape, elast.shape)
	for tag, values in (('elast', elast), ('mu', markups), ('rho', passthroughs)):
		print(tag, sp.weighted_exp(values,shares))
	# Optional diagnostic:
	#sales_density_response(shares, markups, passthroughs, pi=0.5)

	# Slope of the Phillips curve, varying parameters in turn
	vary_pi(shares, markups, passthroughs, frisch_fix=0.2, inc_elast_fix=0.2, philips_fn=philips_slopes_kimball,savefile=None)
	vary_frisch(shares, markups, passthroughs, philips_fn=philips_slopes_kimball,savefile=None)
	#vary_mu_bar(shares, passthroughs, philips_fn=philips_slopes_kimball,savefile=None)
	print(philips_slopes_kimball(shares,markups,passthroughs))
	plt.show()


def tikz_plots_only():
	"""Regenerate the three Phillips-slope figures and save them as TikZ files.

	Loads the Darwinian data with sp.get_key_darwinian_data(pull_col=2),
	replaces the markups with a schedule generated for an average markup of
	1.15, prints share-weighted averages, then writes the Calvo, Frisch and
	average-markup comparative statics to ../Draft/figures/ and shows them.
	"""
	# Darwinian returns to scale data
	shares_dar,markups_dar,passthroughs_dar,prod_dar,output_dar = sp.get_key_darwinian_data(pull_col=2)
	markups_dar = sp.generate_markups(passthroughs_dar, shares_dar, 1.15)
	elast = 1/(1-1/markups_dar)
	for tag, values in (('elast', elast), ('mu', markups_dar), ('rho', passthroughs_dar)):
		print(tag, sp.weighted_exp(values,shares_dar))
	# Optional diagnostic:
	#sales_density_response(shares_dar, markups_dar, passthroughs_dar, pi=0.5)

	# Slope of the Phillips curve, varying parameters in turn
	calvo_file = '../Draft/figures/belgian_calvo_v2.tex'
	frisch_file = '../Draft/figures/belgian_frisch_v2.tex'
	mu_bar_file = '../Draft/figures/belgian_mu_bar_v2.tex'
	vary_pi(shares_dar, markups_dar, passthroughs_dar, frisch_fix=0.2, inc_elast_fix=0.2, philips_fn=philips_slopes_kimball,savefile=calvo_file)
	vary_frisch(shares_dar, markups_dar, passthroughs_dar, philips_fn=philips_slopes_kimball,savefile=frisch_file)
	vary_mu_bar(shares_dar, passthroughs_dar, philips_fn=philips_slopes_kimball,savefile=mu_bar_file)
	#print(philips_slopes_kimball(shares_dar,markups_dar,passthroughs_dar))
	plt.show()

def driver():
	"""Run the full set of exercises: oligopoly simulation, then Darwinian data.

	Part 1 simulates the oligopoly model with the Amiti et al productivity
	draw, plots the Phillips-slope comparative statics, compares passthroughs
	with the Amiti et al estimates, plots the firm distributions and prints
	the flattening ratios. Part 2 repeats the slope calculations with the
	Kimball formulas on the Darwinian data.
	"""
	def report_flattening(slopes):
		# Print the six slopes, then how much each channel flattens the curve.
		calvo_wage, real_rigid_wage, phil_wage, calvo_price, real_rigid_price, phil_price = slopes
		print(calvo_wage, real_rigid_wage, phil_wage, calvo_price, real_rigid_price, phil_price)
		print('Real rigid, wage: ', calvo_wage/real_rigid_wage)
		print('Misallocation, wage: ', real_rigid_wage/phil_wage)
		print('Real rigid, price: ', calvo_price/real_rigid_price)
		print('Misallocation, price: ', real_rigid_price/phil_price)

	#########################################
	# Oligopoly simulation (parameters from Amiti et al)
	firms, prices, shares, markups, passthroughs = oligopoly_sim(n_sectors=1000, replication='amiti')
	# Plots to replicate Amiti et al appendix
	#amiti_replication_plots(firms, prices, shares, markups, passthroughs)
	# Slope of the Phillips curve, varying parameters in turn
	vary_pi(shares, markups, passthroughs, frisch_fix=0.2, inc_elast_fix=0.2)
	vary_frisch(shares, markups, passthroughs)

	# Comparison of oligopoly model passthroughs to empirical distribution from Amiti et al
	compare_passthrough_to_amiti_empirics(firms, prices, shares, markups, passthroughs)

	# Figures 3-4 from Darwinian Returns to Scale, using data from oligopoly model
	plot_darwinian_distributions(firms, shares, markups, passthroughs)
	report_flattening(philips_slopes(shares, markups, passthroughs, frisch=0.2, inc_elast=0.2, pi=0.5))
	plt.show()

	#########################################
	# Darwinian returns to scale data
	shares_dar,markups_dar,passthroughs_dar,prod_dar,output_dar = sp.get_key_darwinian_data()
	markups_dar = sp.generate_markups(passthroughs_dar, shares_dar, 1.15)
	# Print flattening
	report_flattening(philips_slopes_kimball(shares_dar, markups_dar, passthroughs_dar, frisch=0.2, inc_elast=0.2, pi=0.5))
	# Slope of the Phillips curve, varying parameters in turn
	vary_pi(shares_dar, markups_dar, passthroughs_dar, frisch_fix=0.2, inc_elast_fix=0.2, philips_fn=philips_slopes_kimball)
	vary_frisch(shares_dar, markups_dar, passthroughs_dar, philips_fn=philips_slopes_kimball)
	#vary_mu_bar(shares_dar, passthroughs_dar, philips_fn=philips_slopes_kimball, savefile=None)

	# Replication of figures 3-4 from Darwinian Returns to Scale
	plot_darwinian_distributions(prod_dar, shares_dar, markups_dar, passthroughs_dar)

	# Relative output of the simulated firms; only consumed by the comparison
	# plot below, which is currently switched off.
	output = calculate_output(shares, prices, N_firms=45, sigma_within=10, sigma_across=1)
	#compare_passthrough_distributions(output, passthroughs, output_dar, passthroughs_dar)

	plt.show()

# Script entry point. These calls are unguarded (no `if __name__ == "__main__"`),
# so they also run whenever this module is imported.
# NOTE(review): they execute before generate_markups_mu0 and
# plot_changes_due_to_aggmarkup, which are defined further down, exist.
driver()
kw_version()
#tikz_plots_only()

def generate_markups_mu0(rho, shares, mu_0):
	"""Integrate the markup schedule over firm types from an initial markup.

	Types are placed on an evenly spaced interior grid of (0, 1) with one
	point per entry of `rho`. The markup solves
	    dmu/dtheta = mu (mu - 1) (1 - rho) / rho * dlog(lambda)/dtheta
	starting from `mu_0`, where rho(theta) is interpolated linearly (with
	rho = 1 at theta = 0) and dlog(lambda)/dtheta is a finite difference of
	log `shares`. The ODE is integrated over the grid by sp.rk4_integrate.
	"""
	theta = np.linspace(0, 1, len(rho)+2)[1:-1]
	knots = np.append([0], theta)
	rho_fn = interp1d(knots, np.append([1], rho))
	dloglambda = np.diff(np.log(shares)) / np.diff(theta)
	# The first finite difference is repeated so the slope is defined from theta = 0.
	dloglambda_fn = interp1d(knots, np.append([dloglambda[0], dloglambda[0]], dloglambda))

	def dmu(theta, mu):
		rho_here = rho_fn(theta)
		return mu*(mu-1)*(1-rho_here)/rho_here * dloglambda_fn(theta)

	return sp.rk4_integrate(mu_0, theta, dmu)

def plot_changes_due_to_aggmarkup():
	"""Plot how elasticity/passthrough moments change with the average markup.

	Loads the Darwinian data (pull_col=2), sweeps the initial markup mu_0 over
	40 points in [1.0, 1.007], regenerates the markup schedule for each with
	generate_markups_mu0, and records the implied aggregate (harmonic) markup,
	the share-weighted covariance and correlation of elasticity and
	passthrough, the mean elasticity, and two monotonicity checks on
	shares / markups. Five figures are drawn against the aggregate markup and
	shown.
	"""
	shares_dar,markups_dar,passthroughs_dar,prod_dar,output_dar = sp.get_key_darwinian_data(pull_col=2)
	mu_0_arr = np.linspace(1.0, 1.007, 40)
	n_points = len(mu_0_arr)
	agg_markup = np.zeros(n_points)
	elastrho_cov = np.zeros(n_points)
	elastrho_corr = np.zeros(n_points)
	e_elast = np.zeros(n_points)
	emp_last = np.zeros(n_points)
	num_nonmonotonic = np.zeros(n_points)

	# The passthrough variance does not depend on the markup schedule.
	var_rho = sp.weighted_cov(passthroughs_dar, passthroughs_dar, shares_dar)
	for i,mu_0 in enumerate(mu_0_arr):
		markups_dar = generate_markups_mu0(passthroughs_dar, shares_dar, mu_0)
		agg_markup[i] = 1/(sp.weighted_exp(1/markups_dar, shares_dar))
		elast = 1/(1-1/markups_dar)
		# shares / markups and its first differences across types.
		emp = shares_dar/markups_dar
		emp_steps = np.diff(emp)
		emp_last[i] = emp_steps[-1]
		num_nonmonotonic[i] = sum(emp_steps < 0) / len(emp)
		e_elast[i] = sp.weighted_exp(elast, shares_dar)
		elastrho_cov[i] = sp.weighted_cov(elast, passthroughs_dar, shares_dar)
		elastrho_corr[i] = elastrho_cov[i]/np.sqrt(sp.weighted_cov(elast, elast, shares_dar) * var_rho)
		print('Aggregate markup: ', agg_markup[i], emp_last[i], elastrho_cov[i], elastrho_corr[i])
	print(e_elast)

	def new_figure(series, fig_title, ylim=None, zero_line=False):
		# One figure of `series` against the aggregate markup.
		plt.figure()
		plt.plot(agg_markup,series)
		plt.xlabel('Average markup')
		if zero_line:
			plt.axhline(y=0, color='black', linewidth=0.5)
		if ylim is not None:
			plt.ylim(ylim)
		plt.title(fig_title)

	new_figure(elastrho_cov, '$Cov_\\lambda[\\sigma_\\theta, \\rho_\\theta]$')
	new_figure(elastrho_cov/e_elast, '$Cov_\\lambda[\\sigma_\\theta, \\rho_\\theta]/E_\\lambda[\\sigma_\\theta]$', ylim=[0.282,0.283])
	new_figure(elastrho_corr, '$Corr_\\lambda[\\sigma_\\theta, \\rho_\\theta]$', ylim=[0.994,0.995])
	new_figure(emp_last, 'Monotonic? (If greater than zero)', zero_line=True)
	new_figure(num_nonmonotonic, 'Percent of distribution nonmonotonic', zero_line=True)
	plt.show()

